import patsy
import statsmodels.discrete.discrete_model as ds
import pandas as pd

def _split_groups(newdf, treatment, prscore, vals):
    """Partition the row positions of ``newdf`` into the two treatment groups.

    Returns ``(pool, targets)``: two dicts mapping row position -> propensity
    score. ``targets`` is the smaller group (the one being matched) and
    ``pool`` is the larger group that supplies matches; when both groups have
    the same size, the group whose treatment equals ``vals[1]`` is the target.
    Rows whose treatment equals neither value (e.g. NaN) are left out of both.
    """
    first, second = {}, {}
    # newdf has a fresh RangeIndex, so enumerate positions equal its labels
    # and can be used directly with .iloc later.
    for pos, (treat, score) in enumerate(zip(newdf[treatment], newdf[prscore])):
        if treat == vals[0]:
            first[pos] = score
        elif len(vals) > 1 and treat == vals[1]:
            second[pos] = score
    if len(second) > len(first):
        first, second = second, first
    return first, second


def _nearest_matches(score, pool, threshold, matchingnumber):
    """Return up to ``matchingnumber`` pool positions closest to ``score``.

    A candidate qualifies when ``abs(score - candidate) <= threshold``.
    Qualifying candidates are ranked by distance. Because ``list.sort`` is
    stable, equally distant candidates keep pool order, so earlier rows win
    ties. Each tied candidate is kept as its own match.
    """
    within = []
    for pos, other in pool.items():
        dist = abs(score - other)
        if dist <= threshold:
            within.append((dist, pos))
    within.sort(key=lambda pair: pair[0])
    return [pos for _, pos in within[:max(matchingnumber, 0)]]


def matchit(df,treatment, prscore,resamp = False,threshold = 1.0,matchingnumber = 1,treatedkeep = True):
    """Nearest-neighbour propensity-score matching between two treatment groups.

    The index of ``df`` is reset (old index dropped). The returned DataFrame
    is a row subset of that reset frame, so its index holds positions in
    ``df``, listed in ascending order.

    df -- pandas DataFrame containing treated and non-treated samples. Any
          index is accepted; rows are addressed by position.

    treatment -- name of the treatment column (values should be 1 or 0;
                 at most two distinct values are allowed).

    prscore -- name of the propensity-score column (e.g. predictions from a
               logistic or probit regression).

    resamp -- if False, each row of the larger group can be used as a match
              at most once. If True, rows may be reused, and each appears
              only once in the output.

    threshold -- maximum absolute propensity-score distance for a match
                 (inclusive). The default of 1.0 lets every observation find
                 a match when scores lie in [0, 1].

    matchingnumber -- maximum number of matches to find for each observation
                      of the smaller group.

    treatedkeep -- if True, observations of the smaller group with no match
                   within ``threshold`` are still included in the output;
                   if False, they are dropped.

    Raises ValueError if there are more than two treatment groups or a
    treatment value cannot be converted with ``int()``.
    """
    newdf = df.reset_index(drop = True)

    vals = newdf[treatment].unique()
    if len(vals) > 2:
        raise ValueError('there are too many treatment groups')

    for val in vals:
        try:
            int(val)
        except (TypeError, ValueError) as err:
            raise ValueError('treatment variable is not integer') from err

    # The less frequent group (assumed to be the treated group) is matched
    # against the more frequent one.
    pool, targets = _split_groups(newdf, treatment, prscore, vals)

    matching = {}
    for pos, score in targets.items():
        matches = _nearest_matches(score, pool, threshold, matchingnumber)
        matching[pos] = matches
        if not resamp:
            # Without resampling, a matched row leaves the candidate pool.
            for ctrl in matches:
                del pool[ctrl]

    if not treatedkeep:
        matching = {pos: ctrls for pos, ctrls in matching.items() if ctrls}

    matched = []
    for pos, ctrls in matching.items():
        matched.extend(ctrls)
        matched.append(pos)

    # Under resampling the same row can be matched repeatedly; keep it once.
    unique = sorted(set(matched))

    # Sanity check: without resampling no row can be matched twice.
    if not resamp and len(unique) != len(matched):
        raise ValueError('something is wrong')

    return newdf.iloc[unique, :]

if __name__ == "__main__":
    pr = pd.read_csv("data.csv")
    formula = "treatment ~ prts_d + road_area + private + statecd + fortyps + biop_pr_d + plp_pr_d"
    # patsy drops rows with missing values by default; with
    # return_type='dataframe', X keeps the index labels of the rows it used.
    y, X = patsy.dmatrices(formula, pr, return_type='dataframe')
    prmod = ds.Probit(endog = y, exog = X)
    # NOTE(review): confirm fit_regularized actually honours cov_type here.
    matchmod = prmod.fit_regularized(cov_type = "HC1")
    # Label each prediction with the row it was computed from. A fresh
    # RangeIndex would misalign scores in the merge below whenever patsy
    # dropped rows.
    prscores = pd.Series(matchmod.predict(), index = X.index, name = 'prscore')
    plts_pr = pr.merge(prscores, left_index = True, right_index = True).reset_index(drop = True)
    matched_df = matchit(plts_pr,'treatment','prscore', resamp = False, threshold = .01, treatedkeep = False)